import logging
import types as _types

from . import _confidence
from . import _determinize_lattice_pruned as _dlp
from . import _lattice_functions as _lat_fun

from ._confidence import *
from ._compose_lattice_pruned import *
from ._determinize_lattice_pruned import *
from ._lattice_functions import *
from ._minimize_lattice import *
from ._push_lattice import *

from .. import fstext as _fst
from ..fstext import _api


def sentence_level_confidence(lat):
    """Computes sentence level confidence scores.

    If input is a compact lattice, this function requires that distinct paths
    in `lat` have distinct word sequences; this will automatically be the case
    if `lat` was generated by a decoder, since a deterministic FST has this
    property. If input is a state-level lattice, it is first determinized, but
    this is done in a "smart" way so that only paths needed for this operation
    are generated.

    This function assumes that any acoustic scaling you want to apply has
    already been applied.

    The output consists of the following. `confidence` is the score difference
    between the best path and the second-best path in the lattice (a positive
    number), or zero if the lattice was equivalent to the empty FST (no
    successful paths), or infinity if there was only one path in the lattice.
    `num_paths` is a number in `{0, 1, 2}` saying how many n-best paths (up to
    two) were found. If `num_paths >= 1`, `best_sentence` is the best
    word-sequence; if `num_paths == 2`, `second_best_sentence` is the second
    best word-sequence (this may be useful for testing whether the two best
    word sequences are somehow equivalent for the task at hand).

    Args:
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `CompactLatticeVectorFst` is processed as a compact lattice;
            any other input is processed as a state-level lattice.

    Returns:
        Tuple[float, int, List[int], List[int]]: The tuple
        `(confidence, num_paths, best_sentence, second_best_sentence)`.

    Note:
        This function is not the only way to get confidences in Kaldi. This only
        gives you sentence-level (utterance-level) confidence. You can get
        word-by-word confidence within a sentence, along with Minimum Bayes Risk
        decoding. Also confidences estimated using this function are not very
        accurate.
    """
    if isinstance(lat, _fst.CompactLatticeVectorFst):
        return _confidence._sentence_level_confidence_from_compact_lattice(lat)
    return _confidence._sentence_level_confidence_from_lattice(lat)


def determinize_lattice_phone_pruned(ifst, trans_model, prune,
                                     opts=None, destructive=True):
    """Determinizes a raw state-level lattice using phone-aided pruning.

    For every input-symbol sequence of the lattice, only the best
    output-symbol sequence (typically transition ids) is kept. When
    `opts.phone_determinize == True`, phones are inserted and a first,
    phone-level determinization pass is run. The inserted phones are then
    removed and, when `opts.word_determinize == True`, a second pass
    determinizes the resulting word lattice. Pruning is carried out as part
    of determinization, which is more efficient and prevents blowup.

    Args:
        ifst (LatticeFst): The input lattice.
        trans_model (TransitionModel): The transition model.
        prune (float): The pruning beam.
        opts (DeterminizeLatticePhonePrunedOptions): The options for lattice
            determinization. Default-constructed options are used when this
            is `None`.
        destructive (bool): Whether to use the destructive version of the
            algorithm, which mutates the input lattice. The destructive
            version is only used when `ifst` is a mutable FST. Otherwise, or
            when this is `False`, the algorithm runs on a `LatticeVectorFst`
            copy of `ifst`.

    Returns:
        CompactLatticeVectorFst: The output lattice.

    See Also:
        :meth:`determinize_lattice_pruned`

    Note:
        The point of doing first a phone-level determinization pass and then a
        word-level determinization pass is that it allows us to determinize
        deeper lattices without "failing early" and returning a too-small
        lattice due to the max-mem constraint. The result should be the same
        as word-level determinization in general, but for deeper lattices it is
        a bit faster, despite the fact that we now have two passes of
        determinization by default.

        If determinization stops early because a max_mem, max_loop or
        max_arcs threshold was reached, a warning is logged and the output
        lattice is still returned.
    """
    options = DeterminizeLatticePhonePrunedOptions() if opts is None else opts
    # Copy whenever the caller asked us to preserve the input, or when the
    # input cannot be mutated in place by the destructive algorithm.
    needs_copy = not destructive or not isinstance(ifst, _api._MutableFstBase)
    lattice = _fst.LatticeVectorFst(ifst) if needs_copy else ifst
    result = _fst.CompactLatticeVectorFst()
    finished = _dlp._determinize_lattice_phone_pruned_wrapper(
        trans_model, lattice, prune, result, options)
    if not finished:
        logging.warning(
            "Lattice determinization is terminated early because at least one "
            "of max_mem, max_loop or max_arcs thresholds was reached. If you "
            "want a more detailed log message, rerun this function after "
            "setting verbose level > 0 using kaldi.base.set_verbose_level.")
    return result


def determinize_lattice_pruned(ifst, prune, opts=None, compact_out=True):
    """Determinizes a raw state-level lattice using pruned determinization.

    For every input-symbol sequence of the lattice, only the best
    output-symbol sequence (typically transition ids) is kept. Unlike
    :meth:`determinize_lattice_phone_pruned`, only word-level determinization
    is performed. Pruning is carried out as part of determinization, which is
    more efficient and prevents blowup.

    The algorithm runs on a `LatticeVectorFst` copy of `ifst`. Before
    determinization, the copy is inverted (input and output labels swapped),
    topologically sorted and arc-sorted.

    The output takes one of two forms:

    * `compact_out == True`: strings are stored directly in compact lattice
      weights.
    * `compact_out == False`: each string is spelled out as a sequence of
      arcs in a regular lattice, where all but the first arc have an epsilon
      on the input side.

    Args:
        ifst (LatticeFst): The input lattice.
        prune (float): The pruning beam.
        opts (DeterminizeLatticePrunedOptions): The options for lattice
            determinization. Default-constructed options are used when this
            is `None`.
        compact_out (bool): Whether to output a compact lattice.

    Returns:
        LatticeVectorFst or CompactLatticeVectorFst: The output lattice.

    See Also:
        :meth:`determinize_lattice_phone_pruned`

    Note:
        If determinization stops early because a max_mem, max_loop or
        max_arcs threshold was reached, a warning is logged and the output
        lattice is still returned.
    """
    options = DeterminizeLatticePrunedOptions() if opts is None else opts
    lattice = _fst.LatticeVectorFst(ifst).invert().topsort().arcsort()
    if compact_out:
        result = _fst.CompactLatticeVectorFst()
        determinize = _dlp._determinize_lattice_pruned_to_compact
    else:
        result = _fst.LatticeVectorFst()
        determinize = _dlp._determinize_lattice_pruned
    if not determinize(lattice, prune, result, options):
        logging.warning(
            "Lattice determinization is terminated early because at least one "
            "of max_mem, max_loop or max_arcs thresholds was reached. If you "
            "want a more detailed log message, rerun this function after "
            "setting verbose level > 0 using kaldi.base.set_verbose_level.")
    return result


def lattice_state_times(lat):
    """Extracts lattice state times (in terms of frames).

    Walks the states of a topologically sorted lattice and computes the time
    instance that corresponds to each of them.

    Args:
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.

    Returns:
        Tuple[int, List[int]]: The number of frames and the state times.

    Note:
        For a regular lattice, the number of frames equals the maximum state
        time in the lattice. For a compact lattice, the number of frames may
        differ from the maximum state time, because final states can carry
        frames.
    """
    if not isinstance(lat, _fst.LatticeVectorFst):
        return _lat_fun._compact_lattice_state_times(lat)
    return _lat_fun._lattice_state_times(lat)


def compute_lattice_alphas_and_betas(lat, viterbi):
    """Computes forward and backward scores for lattice states.

    With `viterbi == True`, the forward (alpha) and backward (beta) scores
    are those of the best paths reaching and leaving each state. Otherwise,
    regular (summed) forward and backward scores are computed. Alphas and
    betas are negated costs. The input lattice must be topologically sorted.

    Args:
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.
        viterbi (bool): Whether to compute Viterbi scores.

    Returns:
        Tuple[float, List[float], List[float]]: The total-prob (or best-path
        prob), the forward (alpha) scores and the backward (beta) scores.
    """
    compute = (_lat_fun._compute_lattice_alphas_and_betas
               if isinstance(lat, _fst.LatticeVectorFst)
               else _lat_fun._compute_compact_lattice_alphas_and_betas)
    return compute(lat, viterbi)


def top_sort_lattice_if_needed(lat):
    """Topologically sorts `lat` unless it is already sorted.

    Nothing is returned; the sorted result is left in `lat`.

    Args:
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.

    Raises:
        RuntimeError: If lattice cannot be topologically sorted.
    """
    sort_fn = (_lat_fun._top_sort_lattice_if_needed
               if isinstance(lat, _fst.LatticeVectorFst)
               else _lat_fun._top_sort_compact_lattice_if_needed)
    sort_fn(lat)


def prune_lattice(beam, lat):
    """Prunes `lat` with the given beam.

    Nothing is returned; the pruned result is left in `lat`.

    Args:
        beam (float): The pruning beam.
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.

    Raises:
        ValueError: If pruning fails.
    """
    if not isinstance(lat, _fst.LatticeVectorFst):
        _lat_fun._prune_compact_lattice(beam, lat)
        return
    _lat_fun._prune_lattice(beam, lat)


def rescore_lattice(decodable, lat):
    """Adjusts acoustic scores in the lattice.

    The negated scores obtained from the decodable object are *added* to the
    acoustic scores on the arcs. To replace the scores instead, first zero
    them with :meth:`scale_compact_lattice`. The input labels (or, for a
    compact lattice, the string component of the arc weights) are
    interpreted as transition-ids or whatever other index the decodable
    object expects.

    Args:
        decodable (DecodableInterface): The decodable object.
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.

    Raises:
        ValueError: If the inputs are not compatible.

    See Also:
        :meth:`rescore_compact_lattice_speedup`
    """
    rescore = (_lat_fun._rescore_lattice
               if isinstance(lat, _fst.LatticeVectorFst)
               else _lat_fun._rescore_compact_lattice)
    rescore(decodable, lat)


def longest_sentence_length_in_lattice(lat):
    """Returns the number of words in the longest sentence in a lattice.

    Args:
        lat (LatticeVectorFst or CompactLatticeVectorFst): The input lattice.
            A `LatticeVectorFst` is handled as a regular lattice; anything
            else is handled as a compact lattice.

    Returns:
        int: The length of the longest sentence in the lattice.
    """
    if not isinstance(lat, _fst.LatticeVectorFst):
        return _lat_fun._longest_sentence_length_in_compact_lattice(lat)
    return _lat_fun._longest_sentence_length_in_lattice(lat)


# Public API: every name bound at this point (the wrappers above plus
# whatever the star imports brought in), excluding private names ("_" prefix),
# "*Base" classes, and module objects. Without the module filter, the
# `logging` module imported at the top of this file would be re-exported by
# `from <this module> import *`.
__all__ = [name for name in dir()
           if name[0] != '_'
           and not name.endswith('Base')
           and not isinstance(globals()[name], _types.ModuleType)]
